In [1]:
import os
# Work out of the deployed copy of the project.
os.chdir('/data/l989o/deployed/a')
import sys
# Drop the non-deployed checkout from sys.path so project imports (`data`,
# `metadata`, ...) resolve against the deployed tree chdir'd into above.
# NOTE(review): assumes both paths exist on this host — confirm.
if '/data/l989o/a' in sys.path:
    sys.path.remove('/data/l989o/a')

Studying patients

Computing fractions

In [2]:
import scanpy as sc
import anndata as ad
import numpy as np
import matplotlib.pyplot as plt
import os
import h5py
import pickle
import pandas as pd
from data import file_path
from tqdm.notebook import tqdm
True
<KeysViewHDF5 ['count', 'maximum', 'mean', 'sum', 'variance']>
In [3]:
from data import TransformedMeanDataset

# Count cells in each split; `n` (train size) is reused downstream to slice
# the merged AnnData back down to the training cells.
ds0 = TransformedMeanDataset('train')
n = sum(len(batch) for batch in ds0)
print(n)

ds1 = TransformedMeanDataset('validation')
v = sum(len(batch) for batch in ds1)
print(v)
print(n + v)
449434
220218
669652
In [4]:
class Quantities:
    """Container for per-method artifacts, keyed by embedding method name
    ('raw', 'transformed', 'vae_mu')."""

    def __init__(self):
        self.phenographs = {}             # concatenated phenograph labels
        self.adata = {}                   # sliced AnnData objects
        self.num_phenograph_classes = {}  # number of distinct phenograph labels
        self.all_fractions = {}           # per-ome label-fraction matrices
        self.omes = {}                    # ome names present in the hdf5 file

q = Quantities()

for method in ['raw', 'transformed', 'vae_mu']:
    # Load the UMAP AnnData for this method and keep only the training cells.
    adata_path = file_path(f'umap_{method}.adata')
    a = ad.read(adata_path)
    print(a)
    a = a[:n].copy()
    print(a)

    # Per-ome index ranges into the merged cell matrix; close the file handle
    # (the original leaked it via pickle.load(open(...))).
    with open(file_path('merged_cells_info.pickle'), 'rb') as pickle_file:
        index_info_omes, index_info_begins, index_info_ends = pickle.load(pickle_file)
    print(index_info_ends[-1])

    # Collect phenograph labels ome by ome; was `f`, shadowing the path string
    # read into `f` above — renamed for clarity.
    per_ome_phenographs = []
    phenograph_path = file_path(f'phenograph_{method}.hdf5')
    with h5py.File(phenograph_path, 'r') as h5:
        for i, o in enumerate(index_info_omes):
            phenograph = h5[o]['phenograph'][...].reshape((-1, 1))
            assert len(phenograph) == index_info_ends[i] - index_info_begins[i]
            per_ome_phenographs.append(phenograph)
    phenographs = np.concatenate(per_ome_phenographs, axis=0)
    q.phenographs[method] = phenographs

    # Attach labels to the AnnData as a categorical obs column.
    s = pd.Series(phenographs.flatten(), dtype='category')
    print(s)
    s.index = a.obs.index

    a.obs['phenograph'] = s
    display(a.obs)
    q.adata[method] = a

    # Labels are expected to be a contiguous 0..k-1 range; was assigned to `a`,
    # clobbering the AnnData name — renamed to `n_unique`.
    num_phenograph_classes = np.max(phenographs) + 1
    q.num_phenograph_classes[method] = num_phenograph_classes
    print(num_phenograph_classes)
    n_unique = len(set(phenographs.flatten().tolist()))
    assert num_phenograph_classes == n_unique, n_unique

    # Per-ome fraction of cells in each phenograph class; np.bincount replaces
    # the original per-element Python counting loop.
    phenograph_fractions = dict()
    with h5py.File(phenograph_path, 'r') as h5:
        q.omes[method] = list(h5.keys())
        for i, o in tqdm(enumerate(index_info_omes), desc='computing fractions'):
            phenograph = h5[o]['phenograph'][...].reshape((-1, 1))
            assert len(phenograph) == index_info_ends[i] - index_info_begins[i]
            counts = np.bincount(phenograph.flatten(), minlength=num_phenograph_classes)
            fractions = counts / np.sum(counts)
            phenograph_fractions[o] = fractions.reshape((1, -1))
    print(len(phenograph_fractions))
    all_fractions = np.concatenate(list(phenograph_fractions.values()), axis=0)
    print(all_fractions.shape)
    q.all_fractions[method] = all_fractions
AnnData object with n_obs × n_vars = 669652 × 39
    uns: 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
AnnData object with n_obs × n_vars = 449434 × 39
    uns: 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
449434
0          6
1          3
2          3
3          1
4         22
          ..
449429    19
449430    12
449431     3
449432     2
449433    19
Length: 449434, dtype: category
Categories (58, int64): [0, 1, 2, 3, ..., 54, 55, 56, 57]
phenograph
0 6
1 3
2 3
3 1
4 22
... ...
449429 19
449430 12
449431 3
449432 2
449433 19

449434 rows × 1 columns

58
226
(226, 58)
AnnData object with n_obs × n_vars = 669652 × 39
    uns: 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
AnnData object with n_obs × n_vars = 449434 × 39
    uns: 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
449434
0         0
1         0
2         0
3         0
4         0
         ..
449429    1
449430    8
449431    3
449432    3
449433    1
Length: 449434, dtype: category
Categories (78, int64): [0, 1, 2, 3, ..., 74, 75, 76, 77]
phenograph
0 0
1 0
2 0
3 0
4 0
... ...
449429 1
449430 8
449431 3
449432 3
449433 1

449434 rows × 1 columns

78
226
(226, 78)
AnnData object with n_obs × n_vars = 669652 × 5
    uns: 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
AnnData object with n_obs × n_vars = 449434 × 5
    uns: 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
449434
0         16
1         16
2         19
3         17
4          6
          ..
449429    12
449430     6
449431    22
449432    20
449433    11
Length: 449434, dtype: category
Categories (30, int64): [0, 1, 2, 3, ..., 26, 27, 28, 29]
phenograph
0 16
1 16
2 19
3 17
4 6
... ...
449429 12
449430 6
449431 22
449432 20
449433 11

449434 rows × 1 columns

30
226
(226, 30)
In [5]:
# Imports hoisted out of the loop (they were re-executed on every iteration).
from umap import UMAP
from sklearn.cluster import DBSCAN

q.fractions_u = dict()
q.dbscan_labels = dict()

for method in ['raw', 'transformed', 'vae_mu']:
    print('-' * 100)
    print('method =', method)
    # NOTE(review): the positional 2 binds to n_neighbors, not n_components
    # (the printed repr confirms: UMAP(n_neighbors=2, ...)); the embedding is
    # 2-D only because n_components defaults to 2. Left unchanged because the
    # downstream clusters depend on n_neighbors=2 — confirm intent.
    reducer = UMAP(2, verbose=True)
    u = reducer.fit_transform(q.all_fractions[method])
    q.fractions_u[method] = u

    # Cluster patients in the 2-D fraction embedding; label -1 would mean noise.
    clustering = DBSCAN(eps=1, min_samples=2).fit(u)
    print(clustering.labels_)
    print(clustering)
    q.dbscan_labels[method] = clustering.labels_

    print(type(clustering.labels_))
    print(clustering.labels_.shape)
----------------------------------------------------------------------------------------------------
method = raw
UMAP(n_neighbors=2, verbose=True)
Construct fuzzy simplicial set
Wed Jan 27 23:10:01 2021 Finding Nearest Neighbors
Wed Jan 27 23:10:03 2021 Finished Nearest Neighbor Search
Wed Jan 27 23:10:04 2021 Construct embedding
	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Wed Jan 27 23:10:05 2021 Finished embedding
[ 0  1  2  3  4  4  5  6  7  8  4  9  4 10 11 12 13  1 14  4 15 10 16 17
  8 16 18 19  4  9 13 20 21 20 22 21  4 17 17 23 24 15 14 16 11 18 25  9
 11  4  4 24  3 15 26 16 27  9 28 29 30  1 28 30 26 23 10 18 28 18  0 18
  9 18 23 22 10 26 16  8  6  7 31 12 32  3 32  6 21 33 12 34 31  7 30 12
 33  6 32 17 14 31 31 35 34 14 12 33 25 12  5 26 17 36 36  5  3 37 38 38
 39 39 39 39 40 40 40 40  5 41 41 41 42 37 42  3 43 25 43 22 34 34 44 45
 46 39  5 45 25 25 44 19 19 19 23 29 29 29 27 46 35 47 47 22 35 35  3 48
 48 48 25 25  5 49  2 13 50 46 46 39 50 24  4 44 44 51 24 51 39 39 51 45
 49 22 45 42 42 37 42 29 29 52 29 19 49 53 53 22 22 24 52 52 39 37 29 29
 49 54 39 39 17 50 50 54 39 46]
DBSCAN(eps=1, min_samples=2)
<class 'numpy.ndarray'>
(226,)
----------------------------------------------------------------------------------------------------
method = transformed
UMAP(n_neighbors=2, verbose=True)
Construct fuzzy simplicial set
Wed Jan 27 23:10:05 2021 Finding Nearest Neighbors
Wed Jan 27 23:10:05 2021 Finished Nearest Neighbor Search
Wed Jan 27 23:10:05 2021 Construct embedding
	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Wed Jan 27 23:10:06 2021 Finished embedding
[ 0  1  2  3  4  5  3  6  7  8  9 10  4 11 12  9 13 13 14 15 16 17 18  0
 19 20  7 13 21 14 22  9 23  3 23 23 20 23 19 24 25 26 17 18 12  7  7 17
 27 28 28 29 29  4 30 18 31 17 30  0  2 11 16 17 32 26 26 30  3 33  8 33
  0  9 34 18  7  9 18 26 35 21 26 33 26  9 25 33 36 32  7  8 14 11 35 37
 19 37 14  0 20 11 27 23 25  3  7  3  7  3 38 39  2 40 40  7  7  7 13 41
  5 42 43 43 44 44 44 44 45 45 45 45 41 32  6 16 46  7 46 15 47  9 13 35
 34 39 38 34  5  5 48 31 10 10 34 47 13  1 18 10 49 41 48 15 49 49  3 50
 50 50  7 12 42 51 47 31 52 17 17 53 39  7  7 51 51 51 39 51 52 52 51 17
 34  5 17 54 41 42 36  9 13 12 53 31 48 26 20  3 18 34 13 10 52 39 39 17
 34 54 33 28 24 26 26 22 28 24]
DBSCAN(eps=1, min_samples=2)
<class 'numpy.ndarray'>
(226,)
----------------------------------------------------------------------------------------------------
method = vae_mu
UMAP(n_neighbors=2, verbose=True)
Construct fuzzy simplicial set
Wed Jan 27 23:10:06 2021 Finding Nearest Neighbors
Wed Jan 27 23:10:06 2021 Finished Nearest Neighbor Search
Wed Jan 27 23:10:06 2021 Construct embedding
	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Wed Jan 27 23:10:08 2021 Finished embedding
[ 0  1  2  3  4  3  5  6  7  8  6  9  4  8  9 10 11 12 13 14 15 16 14 17
 18  6 17  9  4 11 19 20 12 21 22  8  6  8  5  1 23  6 12 14  9 23 10  6
 22  4  4  3  5 24 23 14  2 16 21 25 18  1 21  8 26 27 16 28  5 20 29 28
  0  4 23 14  4  4 14 27  6 30 27 28 16 10 23 28 19 26 17 12 11 16 16  8
 23  8 13  5  6  8 22 17 23 26 31 32  4  3 33 34 23 35 35 21 15 17 12 12
 36  7 37 30 38 38 38 38 39 39 32 39 12 26  6 21 25 36  2 14  6  6 40  6
 18 26 33 10  4  4 29  0 19 19 23 19 22 18 14 19  6 11 41 14 42 42  4 43
 43 43 39  0  2 40 19 41 42 16 44 14 34 21 45 27 40 27 34 40 37 34 12 16
 23 45  6 46 11 17 46  6 12 12  6 23 23  6  6  3 45  6 46 19 37 24 37 44
 23  9 37 37 27 10 10 19 31 27]
DBSCAN(eps=1, min_samples=2)
<class 'numpy.ndarray'>
(226,)
In [6]:
# Visualize each method's DBSCAN clustering of the per-patient fraction UMAP.
for method in ['raw', 'transformed', 'vae_mu']:
    labels = q.dbscan_labels[method]
    u = q.fractions_u[method]
    plt.figure(figsize=(9, 9))
    # Random palette of 1000 colors indexed by cluster label (a DBSCAN noise
    # label of -1 would wrap to the last palette row).
    colors = np.random.random((1000, 3))[labels]
    # alpha=0 hides the markers: the scatter only establishes the axis limits,
    # the visible content is the colored text annotations below.
    plt.scatter(u[:, 0], u[:, 1], c=colors, alpha=0)
    ax = plt.gca()
    texts = list(map(str, labels))    
    for i in range(len(u)):
        ax.annotate(texts[i], (u[i, 0], u[i, 1]), color=colors[i])
    # plt.text(u[:, 0], u[:, 1], texts, color=colors)
    # NOTE(review): labels run 0..k-1, so the cluster count is max(labels) + 1;
    # the title understates it by one.
    plt.title(f'found {max(labels)} dbscan clusters')
    plt.show()
In [7]:
# Shared per-point jitter in [-0.02, 0.02); depends on `u` left over from the
# previous cell (hidden state) and is deliberately reused unchanged across all
# three methods in the next cell, so points jitter consistently between plots.
jitter = (np.random.rand(len(u)) - 0.5) / 25
In [8]:
import numpy as np
random_colors = np.random.rand(4, 3)
# Clusters of interest (ids taken from the 'raw' DBSCAN labeling).
hot_indices = [39, 4, 29, 22]
# Patient indices of each hot cluster, fixed from the 'raw' labeling so the
# same patient sets can be tracked across all three methods below.
patients_per_hot_index = dict()
for h in hot_indices:
    l = np.where(q.dbscan_labels['raw'] == h)[0]
    patients_per_hot_index[h] = l

for method in ['raw', 'transformed', 'vae_mu']:
    labels = q.dbscan_labels[method]
    u = q.fractions_u[method]
    plt.figure(figsize=(20, 20))
    colors = np.tile([0., 0., 0.], (len(labels), 1))
    hot_indices = [39, 4, 29, 22]
    for color, index in zip(random_colors, hot_indices):
#         matches = np.where(labels == index)
#         print(len(matches[0]))
#         colors[matches, :] = color
        # Color by membership in the 'raw' hot clusters, not this method's own.
        colors[patients_per_hot_index[index], :] = color
#     jitter = (np.random.rand(len(u)) - 0.5) / 25
    # Scale the shared jitter (from the cell above) to each embedding's extent.
    jitter_x = jitter * (np.max(u[:, 0]) - np.min(u[:, 0]))
    jitter_y = jitter * (np.max(u[:, 1]) - np.min(u[:, 1]))
    # NOTE(review): both scatters are drawn with alpha=0 (invisible); only the
    # annotations are visible, and they use the un-jittered coordinates —
    # confirm the jittered scatter is intentional.
    plt.scatter(u[:, 0] + jitter_x, u[:, 1] + jitter_y, c=colors, alpha=0)
    plt.scatter(u[:, 0], u[:, 1], c=colors, alpha=0)
    ax = plt.gca()
    texts = list(map(str, labels))    
    for i in range(len(u)):
        ax.annotate(texts[i], (u[i, 0], u[i, 1]), color=colors[i])
    # plt.text(u[:, 0], u[:, 1], texts, color=colors)
    plt.title(f'found {max(labels)} dbscan clusters')
    plt.show()
In [9]:
# Cluster-size distributions: one descending step curve per method.
plt.figure()
methods = ['raw', 'transformed', 'vae_mu']
for method in methods:
    cluster_labels = q.dbscan_labels[method]
    print(np.unique(cluster_labels, return_counts=True))
    # One histogram bin per cluster id (labels are contiguous 0..max).
    counts, bin_edges = np.histogram(cluster_labels, max(cluster_labels) + 1)
    bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    # Sizes sorted largest-first, so the curves are directly comparable.
    plt.step(bin_centers, sorted(counts, reverse=True), where='mid')
plt.show()
(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54]), array([ 2,  3,  2,  6, 10,  6,  4,  3,  3,  5,  4,  3,  6,  3,  4,  3,  5,
        6,  6,  5,  2,  3,  7,  4,  5,  7,  4,  2,  3,  9,  3,  4,  3,  3,
        4,  4,  2,  4,  2, 12,  4,  3,  5,  2,  4,  4,  5,  2,  3,  4,  4,
        3,  3,  2,  2]))
(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54]), array([ 5,  2,  3,  9,  3,  5,  2, 15,  3,  8,  5,  4,  4,  8,  4,  3,  3,
       10,  7,  3,  4,  2,  2,  5,  3,  3,  9,  2,  4,  2,  3,  4,  3,  5,
        7,  3,  2,  2,  2,  6,  2,  4,  3,  2,  4,  4,  2,  3,  3,  3,  3,
        6,  4,  2,  2]))
(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46]), array([ 4,  3,  4,  5, 12,  5, 19,  2,  8,  5,  6,  5, 10,  2, 10,  2,  8,
        6,  4,  9,  2,  6,  4, 13,  2,  2,  5,  7,  4,  2,  2,  2,  2,  2,
        4,  2,  2,  6,  4,  4,  4,  2,  3,  3,  2,  3,  3]))

Matching DBSCAN clusters

In [10]:
from sklearn.metrics import adjusted_rand_score
from operator import itemgetter

def _relabel_by_first_occurrence(labels_in_order):
    """Map labels to 0..k-1 in order of first appearance in `labels_in_order`.

    Asserts the labels form a contiguous 0..max range, and that the mapping is
    a pure renaming (adjusted Rand index == 1 against the input).
    """
    seen = set()
    no_duplicates = [x for x in labels_in_order if not (x in seen or seen.add(x))]
    assert len(no_duplicates) == max(seen) + 1
    # dict lookup replaces the original O(n^2) list.index() scan
    index_of = {label: i for i, label in enumerate(no_duplicates)}
    relabeled = [index_of[x] for x in labels_in_order]
    s = adjusted_rand_score(labels_in_order, relabeled)
    assert np.isclose(s, 1.), s
    return relabeled


def get_m(method0, method1):
    """Compare the DBSCAN clusterings of two methods.

    Prints the adjusted Rand score, plots label-vs-label scatters and the
    contingency matrix (raw, column- and row-normalized), and returns the raw
    n0 x n1 count matrix.
    """
    print('-' * 100)
    labels0 = q.dbscan_labels[method0]
    labels1 = q.dbscan_labels[method1]
    n0 = max(labels0) + 1
    n1 = max(labels1) + 1
    print('score:', adjusted_rand_score(labels0, labels1))
    plt.figure()
    plt.scatter(labels0, labels1)
    plt.show()

    # Sort method0's labels by decreasing cluster size (ties broken by label
    # value), remembering the permutation so method1's labels can follow it.
    counts0 = dict(zip(*np.unique(labels0, return_counts=True)))
    indices, sorted_labels0 = zip(*sorted(enumerate(labels0), key=lambda x: (-counts0[x[1]], x[1])))

    # Relabeling both sequences by first occurrence makes them comparable
    # regardless of the arbitrary ids DBSCAN assigned; logic shared via helper
    # (was duplicated inline, with an unused `my_map` zip).
    relabeled0 = _relabel_by_first_occurrence(sorted_labels0)
    sorted_labels1 = [labels1[i] for i in indices]
    relabeled1 = _relabel_by_first_occurrence(sorted_labels1)

    plt.figure()
    print(list(zip(relabeled0, sorted_labels0)))
    print(list(zip(relabeled1, sorted_labels1)))
    plt.scatter(relabeled0, sorted_labels0)
    plt.show()

    # Contingency matrix of co-occurring labels.
    m = np.zeros((n0, n1))
    for a, b in zip(labels0, labels1):
        m[a, b] += 1

    # Division already allocates new arrays; the original .copy() calls were
    # redundant.
    m_rows = m / np.sum(m, axis=0, keepdims=True)
    m_cols = m / np.sum(m, axis=1, keepdims=True)

    plt.figure()
    plt.imshow(m, cmap='inferno')
    plt.colorbar()
    plt.show()

    plt.figure()
    plt.imshow(m_rows)
    plt.colorbar()
    plt.show()

    plt.figure()
    plt.imshow(m_cols)
    plt.colorbar()
    plt.show()
    return m

# Pairwise comparison of all three methods; the first two return values are
# discarded, only the last matrix becomes the cell output.
get_m('raw', 'transformed')
get_m('transformed', 'vae_mu')
get_m('vae_mu', 'raw')
----------------------------------------------------------------------------------------------------
score: 0.10661445547812556
[(0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (0, 39), (1, 4), (1, 4), (1, 4), (1, 4), (1, 4), (1, 4), (1, 4), (1, 4), (1, 4), (1, 4), (2, 29), (2, 29), (2, 29), (2, 29), (2, 29), (2, 29), (2, 29), (2, 29), (2, 29), (3, 22), (3, 22), (3, 22), (3, 22), (3, 22), (3, 22), (3, 22), (4, 25), (4, 25), (4, 25), (4, 25), (4, 25), (4, 25), (4, 25), (5, 3), (5, 3), (5, 3), (5, 3), (5, 3), (5, 3), (6, 5), (6, 5), (6, 5), (6, 5), (6, 5), (6, 5), (7, 12), (7, 12), (7, 12), (7, 12), (7, 12), (7, 12), (8, 17), (8, 17), (8, 17), (8, 17), (8, 17), (8, 17), (9, 18), (9, 18), (9, 18), (9, 18), (9, 18), (9, 18), (10, 9), (10, 9), (10, 9), (10, 9), (10, 9), (11, 16), (11, 16), (11, 16), (11, 16), (11, 16), (12, 19), (12, 19), (12, 19), (12, 19), (12, 19), (13, 24), (13, 24), (13, 24), (13, 24), (13, 24), (14, 42), (14, 42), (14, 42), (14, 42), (14, 42), (15, 46), (15, 46), (15, 46), (15, 46), (15, 46), (16, 6), (16, 6), (16, 6), (16, 6), (17, 10), (17, 10), (17, 10), (17, 10), (18, 14), (18, 14), (18, 14), (18, 14), (19, 23), (19, 23), (19, 23), (19, 23), (20, 26), (20, 26), (20, 26), (20, 26), (21, 31), (21, 31), (21, 31), (21, 31), (22, 34), (22, 34), (22, 34), (22, 34), (23, 35), (23, 35), (23, 35), (23, 35), (24, 37), (24, 37), (24, 37), (24, 37), (25, 40), (25, 40), (25, 40), (25, 40), (26, 44), (26, 44), (26, 44), (26, 44), (27, 45), (27, 45), (27, 45), (27, 45), (28, 49), (28, 49), (28, 49), (28, 49), (29, 50), (29, 50), (29, 50), (29, 50), (30, 1), (30, 1), (30, 1), (31, 7), (31, 7), (31, 7), (32, 8), (32, 8), (32, 8), (33, 11), (33, 11), (33, 11), (34, 13), (34, 13), (34, 13), (35, 15), (35, 15), (35, 15), (36, 21), (36, 21), (36, 21), (37, 28), (37, 28), (37, 28), (38, 30), (38, 30), (38, 30), (39, 32), (39, 32), (39, 32), (40, 33), (40, 33), (40, 33), (41, 41), (41, 41), (41, 41), (42, 48), (42, 48), (42, 48), (43, 51), (43, 51), (43, 51), (44, 52), (44, 52), (44, 52), (45, 0), (45, 0), (46, 2), (46, 2), (47, 20), (47, 
20), (48, 27), (48, 27), (49, 36), (49, 36), (50, 38), (50, 38), (51, 43), (51, 43), (52, 47), (52, 47), (53, 53), (53, 53), (54, 54), (54, 54)]
[(0, 5), (1, 42), (2, 43), (2, 43), (3, 39), (4, 53), (5, 52), (5, 52), (5, 52), (6, 33), (7, 28), (7, 28), (8, 4), (0, 5), (9, 9), (8, 4), (10, 15), (11, 21), (12, 20), (7, 28), (7, 28), (13, 7), (14, 0), (15, 47), (16, 13), (17, 1), (9, 9), (16, 13), (4, 53), (3, 39), (18, 17), (19, 23), (20, 18), (10, 15), (10, 15), (0, 5), (21, 3), (20, 18), (13, 7), (13, 7), (13, 7), (0, 5), (0, 5), (13, 7), (22, 12), (21, 3), (23, 29), (9, 9), (13, 7), (24, 16), (21, 3), (21, 3), (25, 38), (13, 7), (26, 45), (25, 38), (1, 42), (9, 9), (6, 33), (13, 7), (27, 37), (13, 7), (21, 3), (14, 0), (19, 23), (28, 19), (14, 0), (29, 2), (30, 24), (13, 7), (13, 7), (31, 30), (6, 33), (6, 33), (9, 9), (32, 10), (33, 14), (18, 17), (18, 17), (14, 0), (20, 18), (12, 20), (20, 18), (20, 18), (20, 18), (16, 13), (34, 31), (32, 10), (32, 10), (34, 31), (35, 25), (23, 29), (13, 7), (3, 39), (36, 34), (37, 41), (38, 6), (39, 54), (37, 41), (40, 36), (36, 34), (32, 10), (18, 17), (18, 17), (30, 24), (38, 6), (41, 35), (6, 33), (27, 37), (42, 11), (18, 17), (43, 26), (13, 7), (33, 14), (18, 17), (12, 20), (21, 3), (30, 24), (43, 26), (36, 34), (36, 34), (31, 30), (44, 32), (9, 9), (3, 39), (43, 26), (33, 14), (42, 11), (45, 27), (46, 8), (35, 25), (15, 47), (9, 9), (19, 23), (47, 49), (47, 49), (47, 49), (13, 7), (44, 32), (1, 42), (3, 39), (48, 44), (48, 44), (48, 44), (48, 44), (16, 13), (49, 48), (50, 51), (50, 51), (41, 35), (36, 34), (18, 17), (18, 17), (50, 51), (36, 34), (49, 48), (36, 34), (5, 52), (3, 39), (43, 26), (43, 26), (17, 1), (16, 13), (42, 11), (13, 7), (11, 21), (42, 11), (46, 8), (28, 19), (43, 26), (22, 12), (22, 12), (45, 27), (16, 13), (51, 22), (34, 31), (24, 16), (43, 26), (8, 4), (19, 23), (19, 23), (40, 36), (31, 30), (24, 16), (21, 3), (29, 2), (18, 17), (41, 35), (43, 26), (35, 25), (33, 14), (44, 32), (28, 19), (21, 3), (26, 45), (26, 45), (26, 45), (52, 50), (52, 50), (52, 50), (50, 51), (50, 51), (50, 51), (22, 12), (16, 13), (32, 10), (14, 0), (46, 8), (29, 2), 
(15, 47), (9, 9), (21, 3), (34, 31), (20, 18), (53, 40), (53, 40), (16, 13), (37, 41), (54, 46), (54, 46), (37, 41), (49, 48), (43, 26), (12, 20), (39, 54), (51, 22)]
----------------------------------------------------------------------------------------------------
score: 0.17162192099917203
[(0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (0, 7), (1, 17), (1, 17), (1, 17), (1, 17), (1, 17), (1, 17), (1, 17), (1, 17), (1, 17), (1, 17), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (3, 26), (3, 26), (3, 26), (3, 26), (3, 26), (3, 26), (3, 26), (3, 26), (3, 26), (4, 9), (4, 9), (4, 9), (4, 9), (4, 9), (4, 9), (4, 9), (4, 9), (5, 13), (5, 13), (5, 13), (5, 13), (5, 13), (5, 13), (5, 13), (5, 13), (6, 18), (6, 18), (6, 18), (6, 18), (6, 18), (6, 18), (6, 18), (7, 34), (7, 34), (7, 34), (7, 34), (7, 34), (7, 34), (7, 34), (8, 39), (8, 39), (8, 39), (8, 39), (8, 39), (8, 39), (9, 51), (9, 51), (9, 51), (9, 51), (9, 51), (9, 51), (10, 0), (10, 0), (10, 0), (10, 0), (10, 0), (11, 5), (11, 5), (11, 5), (11, 5), (11, 5), (12, 10), (12, 10), (12, 10), (12, 10), (12, 10), (13, 23), (13, 23), (13, 23), (13, 23), (13, 23), (14, 33), (14, 33), (14, 33), (14, 33), (14, 33), (15, 11), (15, 11), (15, 11), (15, 11), (16, 12), (16, 12), (16, 12), (16, 12), (17, 14), (17, 14), (17, 14), (17, 14), (18, 20), (18, 20), (18, 20), (18, 20), (19, 28), (19, 28), (19, 28), (19, 28), (20, 31), (20, 31), (20, 31), (20, 31), (21, 41), (21, 41), (21, 41), (21, 41), (22, 44), (22, 44), (22, 44), (22, 44), (23, 45), (23, 45), (23, 45), (23, 45), (24, 52), (24, 52), (24, 52), (24, 52), (25, 2), (25, 2), (25, 2), (26, 4), (26, 4), (26, 4), (27, 8), (27, 8), (27, 8), (28, 15), (28, 15), (28, 15), (29, 16), (29, 16), (29, 16), (30, 19), (30, 19), (30, 19), (31, 24), (31, 24), (31, 24), (32, 25), (32, 25), (32, 25), (33, 30), (33, 30), (33, 30), (34, 32), (34, 32), (34, 32), (35, 35), (35, 35), (35, 35), (36, 42), (36, 42), (36, 42), (37, 47), (37, 47), (37, 47), (38, 48), (38, 48), (38, 48), (39, 49), (39, 49), (39, 49), (40, 50), (40, 50), (40, 50), (41, 1), (41, 1), (42, 6), (42, 6), (43, 21), (43, 21), (44, 22), (44, 22), (45, 27), (45, 27), (46, 29), (46, 29), (47, 36), (47, 36), (48, 37), (48, 
37), (49, 38), (49, 38), (50, 40), (50, 40), (51, 43), (51, 43), (52, 46), (52, 46), (53, 53), (53, 53), (54, 54), (54, 54)]
[(0, 7), (1, 17), (2, 23), (3, 10), (4, 4), (1, 17), (5, 31), (4, 4), (6, 21), (7, 15), (1, 17), (8, 36), (9, 39), (6, 21), (10, 45), (11, 16), (12, 12), (13, 6), (11, 16), (14, 8), (11, 16), (15, 44), (11, 16), (13, 6), (15, 44), (16, 3), (17, 5), (6, 21), (17, 5), (18, 26), (19, 32), (16, 3), (4, 4), (16, 3), (13, 6), (20, 27), (11, 16), (20, 27), (20, 27), (11, 16), (13, 6), (3, 10), (3, 10), (13, 6), (3, 10), (21, 20), (4, 4), (4, 4), (3, 10), (13, 6), (13, 6), (22, 11), (12, 12), (23, 9), (12, 12), (24, 40), (25, 22), (12, 12), (26, 46), (27, 14), (27, 14), (27, 14), (27, 14), (27, 14), (27, 14), (10, 45), (2, 23), (28, 18), (3, 10), (2, 23), (2, 23), (13, 6), (2, 23), (29, 34), (18, 26), (29, 34), (29, 34), (30, 24), (31, 37), (24, 40), (20, 27), (24, 40), (20, 27), (24, 40), (12, 12), (32, 0), (1, 17), (33, 25), (32, 0), (17, 5), (16, 3), (8, 36), (4, 4), (4, 4), (10, 45), (23, 9), (34, 19), (34, 19), (34, 19), (34, 19), (12, 12), (25, 22), (14, 8), (14, 8), (1, 17), (21, 20), (35, 28), (35, 28), (35, 28), (31, 37), (14, 8), (36, 1), (11, 16), (14, 8), (23, 9), (23, 9), (32, 0), (12, 12), (37, 13), (22, 11), (22, 11), (37, 13), (13, 6), (13, 6), (13, 6), (13, 6), (4, 4), (4, 4), (31, 37), (5, 31), (38, 2), (32, 0), (39, 41), (2, 23), (12, 12), (12, 12), (22, 11), (22, 11), (40, 38), (40, 38), (40, 38), (40, 38), (9, 39), (9, 39), (19, 32), (9, 39), (41, 42), (31, 37), (29, 34), (31, 37), (38, 2), (28, 18), (2, 23), (4, 4), (4, 4), (30, 24), (14, 8), (42, 29), (12, 12), (27, 14), (27, 14), (27, 14), (7, 15), (6, 21), (6, 21), (28, 18), (17, 5), (2, 23), (36, 1), (20, 27), (20, 27), (2, 23), (2, 23), (2, 23), (2, 23), (6, 21), (35, 28), (18, 26), (18, 26), (18, 26), (13, 6), (11, 16), (13, 6), (0, 7), (38, 2), (1, 17), (13, 6), (34, 19), (34, 19), (42, 29), (39, 41), (2, 23), (13, 6), (41, 42), (41, 42), (43, 43), (43, 43), (43, 43), (36, 1), (28, 18), (13, 6), (13, 6), (4, 4), (44, 30), (34, 19), (34, 19), (25, 22), (25, 22), (16, 3), (17, 5), (34, 19), (26, 
46), (14, 8), (14, 8), (45, 33), (45, 33), (46, 35), (46, 35), (31, 37), (44, 30), (33, 25), (38, 2), (27, 14), (13, 6), (26, 46), (23, 9)]
----------------------------------------------------------------------------------------------------
score: 0.09578921667908091
[(0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (0, 6), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (1, 23), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (2, 4), (3, 12), (3, 12), (3, 12), (3, 12), (3, 12), (3, 12), (3, 12), (3, 12), (3, 12), (3, 12), (4, 14), (4, 14), (4, 14), (4, 14), (4, 14), (4, 14), (4, 14), (4, 14), (4, 14), (4, 14), (5, 19), (5, 19), (5, 19), (5, 19), (5, 19), (5, 19), (5, 19), (5, 19), (5, 19), (6, 8), (6, 8), (6, 8), (6, 8), (6, 8), (6, 8), (6, 8), (6, 8), (7, 16), (7, 16), (7, 16), (7, 16), (7, 16), (7, 16), (7, 16), (7, 16), (8, 27), (8, 27), (8, 27), (8, 27), (8, 27), (8, 27), (8, 27), (9, 10), (9, 10), (9, 10), (9, 10), (9, 10), (9, 10), (10, 17), (10, 17), (10, 17), (10, 17), (10, 17), (10, 17), (11, 21), (11, 21), (11, 21), (11, 21), (11, 21), (11, 21), (12, 37), (12, 37), (12, 37), (12, 37), (12, 37), (12, 37), (13, 3), (13, 3), (13, 3), (13, 3), (13, 3), (14, 5), (14, 5), (14, 5), (14, 5), (14, 5), (15, 9), (15, 9), (15, 9), (15, 9), (15, 9), (16, 11), (16, 11), (16, 11), (16, 11), (16, 11), (17, 26), (17, 26), (17, 26), (17, 26), (17, 26), (18, 0), (18, 0), (18, 0), (18, 0), (19, 2), (19, 2), (19, 2), (19, 2), (20, 18), (20, 18), (20, 18), (20, 18), (21, 22), (21, 22), (21, 22), (21, 22), (22, 28), (22, 28), (22, 28), (22, 28), (23, 34), (23, 34), (23, 34), (23, 34), (24, 38), (24, 38), (24, 38), (24, 38), (25, 39), (25, 39), (25, 39), (25, 39), (26, 40), (26, 40), (26, 40), (26, 40), (27, 1), (27, 1), (27, 1), (28, 42), (28, 42), (28, 42), (29, 43), (29, 43), (29, 43), (30, 45), (30, 45), (30, 45), (31, 46), (31, 46), (31, 46), (32, 7), (32, 7), (33, 13), (33, 13), (34, 15), (34, 15), (35, 20), (35, 20), (36, 24), (36, 24), (37, 25), (37, 25), (38, 29), (38, 29), (39, 30), (39, 30), (40, 31), (40, 31), (41, 32), (41, 32), (42, 
33), (42, 33), (43, 35), (43, 35), (44, 36), (44, 36), (45, 41), (45, 41), (46, 44), (46, 44)]
[(0, 6), (1, 4), (2, 16), (1, 4), (3, 15), (4, 9), (0, 6), (5, 14), (6, 42), (7, 34), (7, 34), (8, 45), (9, 35), (8, 45), (10, 29), (10, 29), (11, 53), (11, 53), (12, 24), (12, 24), (13, 18), (14, 26), (15, 23), (16, 32), (17, 33), (7, 34), (18, 17), (15, 23), (19, 49), (20, 19), (19, 49), (19, 49), (1, 4), (1, 4), (1, 4), (1, 4), (1, 4), (13, 18), (21, 10), (14, 26), (22, 25), (22, 25), (22, 25), (23, 3), (24, 1), (25, 21), (5, 14), (7, 34), (26, 38), (26, 38), (6, 42), (27, 51), (10, 29), (28, 52), (1, 4), (2, 16), (2, 16), (2, 16), (29, 22), (2, 16), (29, 22), (30, 27), (29, 22), (31, 39), (32, 13), (25, 21), (20, 19), (20, 19), (10, 29), (33, 46), (34, 2), (28, 52), (35, 54), (36, 8), (21, 10), (25, 21), (18, 17), (37, 30), (38, 12), (0, 6), (39, 31), (21, 10), (4, 9), (21, 10), (16, 32), (40, 7), (37, 30), (33, 46), (8, 45), (15, 23), (36, 8), (39, 31), (41, 44), (27, 51), (18, 17), (33, 46), (38, 12), (22, 25), (23, 3), (8, 45), (42, 50), (42, 50), (18, 17), (13, 18), (38, 12), (9, 35), (43, 37), (43, 37), (44, 20), (45, 28), (45, 28), (46, 5), (23, 3), (12, 24), (31, 39), (31, 39), (31, 39), (10, 29), (31, 39), (31, 39), (23, 3), (1, 4), (12, 24), (38, 12), (29, 22), (46, 5), (18, 17), (23, 3), (45, 28), (18, 17), (4, 9), (47, 11), (20, 19), (47, 11), (35, 54), (32, 13), (4, 9), (39, 31), (48, 47), (6, 42), (14, 26), (17, 33), (5, 14), (43, 37), (31, 39), (49, 0), (4, 9), (20, 19), (22, 25), (34, 2), (30, 27), (50, 43), (46, 5), (36, 8), (37, 30), (33, 46), (10, 29), (29, 22), (47, 11), (39, 31), (10, 29), (13, 18), (13, 18), (38, 12), (0, 6), (14, 26), (42, 50), (12, 24), (31, 39), (51, 40), (51, 40), (51, 40), (51, 40), (46, 5), (52, 41), (52, 41), (22, 25), (41, 44), (19, 49), (41, 44), (27, 51), (24, 1), (15, 23), (24, 1), (9, 35), (9, 35), (42, 50), (53, 48), (53, 48), (53, 48), (1, 4), (29, 22), (29, 22), (6, 42), (6, 42), (28, 52), (40, 7), (31, 39), (5, 14), (16, 32), (3, 15), (23, 3), (44, 20), (13, 18), (3, 15), (43, 37), (10, 29), (50, 43), (49, 
0), (41, 44), (40, 7), (31, 39), (38, 12), (31, 39), (17, 33), (52, 41), (46, 5), (46, 5), (54, 36), (54, 36), (31, 39), (22, 25), (48, 47), (32, 13), (33, 46), (10, 29)]
Out[10]:
array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 2., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.]])

Metadata

Coloring by cohort

In [11]:
from metadata import get_metadata
# Patient/image metadata for both cohorts; the "discarding ... / warning ..."
# lines printed below come from get_metadata()'s internal cleaning steps.
df = get_metadata()
display(df)
print(df['Subtype'].value_counts())
discarding 153 omes from the Basel cohort, remaining: 223
discarding 153 omes from the Zurich cohort, remaning: 229
clearing metadata
warning: df_basel[Subtype] contains 5 NAs out of 223 values
warning: df_basel[clinical_type] contains 5 NAs out of 223 values
warning: df_zurich[clinical_type] contains 34 NAs out of 229 values
flattening PTNM_T labels
warning: interpreting the PTNM_T label "[]" as "TX"
flattening PTNM_N labels
warning: interpreting the PTNM_N label "[]" as "pNX"
warning: interpreting the PTNM_N labels "0sl" and "0sn" as "pN0"
warning: interpreting the "M0_IPLUS" label as "cMo(i+)"
metadata cleaned
FileName_FullStack merged_pid diseasestatus PrimarySite Subtype clinical_type Height_FullStack Width_FullStack area sum_area_cells Count_Cells PTNM_T PTNM_N PTNM_M DFSmonth OSmonth images_per_patient images_per_patient_filtered cohort
2 BaselTMA_SP41_15.475kx12.665ky_10000x8500_5_20... 166 tumor breast PR-ER- TripleNeg 723 749 541527 356411 3068 T2 pN0 M0 35.0 37.0 1 1 basel
3 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 238 tumor breast PR+ER+ HR+HER2- 840 712 598080 286198 3173 T1 pN2 M0 140.0 233.0 2 2 basel
4 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 238 non-tumor breast PR+ER+ HR+HER2- 765 689 527085 193119 2121 T1 pN2 M0 140.0 233.0 2 2 basel
6 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 68 tumor breast PR+ER+ HR+HER2- 689 688 474032 218846 2262 T1 pN2 M0 169.0 169.0 2 1 basel
7 BaselTMA_SP41_15.475kx12.665ky_10000x8500_5_20... 72 tumor breast PR-ER+ HR+HER2- 716 737 527692 289717 2740 T1 pN0 pM1 186.0 186.0 2 1 basel
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
350 ZTMA208_slide_13.25kx14.95ky_7000x7000_8_20171... 308 tumor breast PR+ER+ NaN 492 514 252888 243839 3674 TX pNX pM1 NaN NaN 6 4 zurich
354 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 503 475 238925 98597 1256 T3 pN1 pM1 NaN NaN 6 4 zurich
355 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 450 541 243450 102638 1227 T3 pN1 pM1 NaN NaN 6 4 zurich
356 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 464 519 240816 123359 1580 T3 pN1 pM1 NaN NaN 6 4 zurich
357 ZTMA208_slide_20.73kx15.16ky_7000x7000_6_20171... 301 tumor breast PR+ER+ NaN 488 540 263520 241444 3418 T3 pN1 pM1 NaN NaN 6 4 zurich

452 rows × 19 columns

PR+ER+    304
PR-ER-     68
PR-ER+     66
PR+ER-      9
Name: Subtype, dtype: int64
In [12]:
# Cohort per ome: 0 = basel, 1 = zurich. NOTE(review): assumes the row order of
# this .loc selection matches the order of q.omes['raw'] — verify against the
# indexing used when all_fractions was built.
cohort_info = df.loc[df['FileName_FullStack'].isin(q.omes['raw']), 'cohort'].apply(lambda x: 0 if x == 'basel' else 1).to_numpy()
def f(x):
    """Encode a breast-cancer subtype label as an integer category.

    Mapping: 'PR+ER+' -> 0, 'PR-ER-' -> 1, 'PR-ER+' -> 2, 'PR+ER-' -> 3,
    missing (NaN) -> 4.  Any other value raises AssertionError.

    The original relied on `and`/`or` precedence inside a single assert and
    a long if-chain; this version checks NaN first and uses a lookup table.
    `isinstance` (instead of `type(x) == float`) also accepts np.float64 NaN.
    """
    _codes = {'PR+ER+': 0, 'PR-ER-': 1, 'PR-ER+': 2, 'PR+ER-': 3}
    # A missing subtype annotation gets its own category.
    if isinstance(x, float) and np.isnan(x):
        return 4
    assert x in _codes, x
    return _codes[x]
# Integer-encoded tumor subtype per kept image (same row filter as cohort_info,
# so both arrays are index-aligned).
tumor_info = df.loc[df['FileName_FullStack'].isin(q.omes['raw']), 'Subtype'].apply(f).to_numpy()
print(cohort_info.shape)
print(tumor_info.shape)
(226,)
(226,)
In [13]:
# Per-image fraction UMAP: each point is annotated with its DBSCAN cluster id,
# with the text coloured by cohort (red = basel, blue = zurich).
# Kept the np.random.rand call so the RNG stream seen by later cells is unchanged.
random_colors = np.random.rand(4, 3)

for method in ['raw', 'transformed', 'vae_mu']:
    labels = q.dbscan_labels[method]
    u = q.fractions_u[method]
    plt.figure(figsize=(9, 9))
    two_colors = np.array([[1., 0., 0.], [0., 0., 1.]])
    # Report the size of a few clusters of interest.  The original also filled
    # a `colors` array here that no plot ever used; that dead code is removed.
    hot_indices = [39, 4, 29, 22]
    for index in hot_indices:
        matches = np.where(labels == index)
        print(len(matches[0]))
    # Invisible scatter: only establishes sensible axis limits for the annotations.
    plt.scatter(u[:, 0], u[:, 1], c=cohort_info, alpha=0)
    ax = plt.gca()
    texts = list(map(str, labels))
    cohort_colors = two_colors[cohort_info, :]
    for i in range(len(u)):
        ax.annotate(texts[i], (u[i, 0], u[i, 1]), color=cohort_colors[i, :])
    plt.title(f'found {max(labels)} dbscan clusters')
    plt.show()
12
10
9
7
6
3
2
2
4
12
2
4

Coloring by cancer type

In [14]:
# Same annotation plot, but colour the cluster-id labels by tumour subtype
# (5 categories including NaN), one random colour per subtype.
from matplotlib.lines import Line2D  # hoisted: was re-imported on every loop pass

random_colors = np.random.rand(5, 3)

for method in ['raw', 'transformed', 'vae_mu']:
    labels = q.dbscan_labels[method]
    u = q.fractions_u[method]
    plt.figure(figsize=(9, 9))
    # Print cluster-of-interest sizes.  The `colors` array the original built
    # here was never used by any plot call and has been removed.
    hot_indices = [39, 4, 29, 22]
    for index in hot_indices:
        print(len(np.where(labels == index)[0]))
    # Invisible scatter: only establishes axis limits for the annotations.
    plt.scatter(u[:, 0], u[:, 1], alpha=0)
    ax = plt.gca()
    texts = list(map(str, labels))
    subtype_colors = random_colors[tumor_info, :]
    for i in range(len(u)):
        ax.annotate(texts[i], (u[i, 0], u[i, 1]), color=subtype_colors[i, :])
    plt.title(f'found {max(labels)} dbscan clusters')
    custom_lines = [Line2D([0], [0], color=random_colors[i], lw=4) for i in range(5)]
    ax.legend(custom_lines, ['PR+ER+', 'PR-ER-', 'PR-ER+', 'PR+ER-', 'nan'])
    plt.show()
12
10
9
7
6
3
2
2
4
12
2
4

Coloring by patient

In [15]:
# Display the per-image metadata table again for reference.
df
Out[15]:
FileName_FullStack merged_pid diseasestatus PrimarySite Subtype clinical_type Height_FullStack Width_FullStack area sum_area_cells Count_Cells PTNM_T PTNM_N PTNM_M DFSmonth OSmonth images_per_patient images_per_patient_filtered cohort
2 BaselTMA_SP41_15.475kx12.665ky_10000x8500_5_20... 166 tumor breast PR-ER- TripleNeg 723 749 541527 356411 3068 T2 pN0 M0 35.0 37.0 1 1 basel
3 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 238 tumor breast PR+ER+ HR+HER2- 840 712 598080 286198 3173 T1 pN2 M0 140.0 233.0 2 2 basel
4 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 238 non-tumor breast PR+ER+ HR+HER2- 765 689 527085 193119 2121 T1 pN2 M0 140.0 233.0 2 2 basel
6 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 68 tumor breast PR+ER+ HR+HER2- 689 688 474032 218846 2262 T1 pN2 M0 169.0 169.0 2 1 basel
7 BaselTMA_SP41_15.475kx12.665ky_10000x8500_5_20... 72 tumor breast PR-ER+ HR+HER2- 716 737 527692 289717 2740 T1 pN0 pM1 186.0 186.0 2 1 basel
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
350 ZTMA208_slide_13.25kx14.95ky_7000x7000_8_20171... 308 tumor breast PR+ER+ NaN 492 514 252888 243839 3674 TX pNX pM1 NaN NaN 6 4 zurich
354 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 503 475 238925 98597 1256 T3 pN1 pM1 NaN NaN 6 4 zurich
355 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 450 541 243450 102638 1227 T3 pN1 pM1 NaN NaN 6 4 zurich
356 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 464 519 240816 123359 1580 T3 pN1 pM1 NaN NaN 6 4 zurich
357 ZTMA208_slide_20.73kx15.16ky_7000x7000_6_20171... 301 tumor breast PR+ER+ NaN 488 540 263520 241444 3418 T3 pN1 pM1 NaN NaN 6 4 zurich

452 rows × 19 columns

In [16]:
# Patient id (merged across cohorts) for each kept image; index-aligned with
# cohort_info and tumor_info above.
patient_of_origin = df.loc[df['FileName_FullStack'].isin(q.omes['raw']), 'merged_pid'].to_numpy()
patient_of_origin
Out[16]:
array([166, 238, 238, 206, 206, 280, 220,  24,  24, 284, 100, 177, 282,
        45, 211, 239, 227, 234,  33,  39,  61,  44, 159,  38, 114, 268,
       256,  55,  60,  40,  14,   2, 112, 129, 203, 101,  41, 255, 249,
       215,  36,  84, 105,   6,  51, 120, 253, 253, 251, 217,  70,  10,
       264,  54, 188,  74, 192, 189, 154, 183, 130, 109, 174, 232, 148,
        81, 279, 229, 258,   5, 223,  73, 182, 198, 164, 121,  12, 162,
        87,   1, 233, 243, 278, 278,  34, 144, 180, 260, 260, 119, 201,
        67,  75,  23, 272, 200, 242, 207,  26, 140, 195, 252, 213,  95,
        19, 170, 209,  50,  77, 110, 291, 297, 297, 297, 297, 343, 343,
       344, 344, 344, 344, 335, 335, 335, 340, 340, 340, 340, 340, 357,
       357, 357, 357, 311, 311, 311, 311, 293, 293, 293, 309, 309, 309,
       309, 309, 338, 338, 338, 338, 348, 348, 290, 290, 290, 290, 303,
       303, 303, 303, 313, 313, 313, 339, 339, 339, 339, 322, 322, 322,
       286, 286, 287, 287, 287, 326, 326, 326, 315, 337, 337, 337, 307,
       307, 307, 307, 320, 320, 320, 320, 299, 299, 355, 355, 352, 352,
       352, 352, 325, 325, 325, 336, 336, 336, 336, 305, 305, 305, 305,
       332, 332, 332, 332, 329, 329, 316, 316, 316, 316, 351, 351, 351,
       351, 301, 301, 301, 301])
In [17]:
# Annotate each image's UMAP point with its DBSCAN label, coloured per patient
# (one random colour per patient id; 1000 >> number of patients).
random_colors = np.random.rand(1000, 3)

for method in ['raw', 'transformed', 'vae_mu']:
    labels = q.dbscan_labels[method]
    u = q.fractions_u[method]
    plt.figure(figsize=(9, 9))
    # Invisible scatter: only establishes axis limits for the annotations.
    plt.scatter(u[:, 0], u[:, 1], alpha=0)
    ax = plt.gca()
    patient_colors = random_colors[patient_of_origin, :]
    for point, label, color in zip(u, labels, patient_colors):
        ax.annotate(str(label), (point[0], point[1]), color=color)
    plt.title(f'found {max(labels)} dbscan clusters')
    plt.show()
In [18]:
from sklearn.metrics import adjusted_rand_score
from operator import itemgetter

def get_m_patients(method0):
    """Compare the DBSCAN clustering of `method0` against patient-of-origin.

    Prints the adjusted Rand score between the cluster labels and the global
    `patient_of_origin` array, then shows three diagnostic scatter plots:
    raw labels vs patients, cluster-size-sorted sample indices coloured by
    cohort, and size-relabelled clusters vs compacted patient ids.

    Returns None.  The original version contained a confusion-matrix section
    after an early `return` that was unreachable; it has been removed here
    (the same analysis is performed by the cells below).
    """
    print('-' * 100)
    labels0 = q.dbscan_labels[method0]
    print('score:', adjusted_rand_score(labels0, patient_of_origin))
    plt.figure()
    plt.scatter(labels0, patient_of_origin, c=patient_of_origin)
    plt.show()

    # Sort samples by decreasing cluster size (ties broken by label value).
    counts0 = dict(zip(*np.unique(labels0, return_counts=True)))
    indices, sorted_labels0 = zip(*sorted(enumerate(labels0), key=lambda x: (-counts0[x[1]], x[1])))

    # Relabel clusters so that 0 is the largest cluster, 1 the second largest, ...
    seen = set()
    no_duplicates = [x for x in sorted_labels0 if not (x in seen or seen.add(x))]
    assert max(seen) + 1 == len(no_duplicates)
    relabeled0 = [no_duplicates.index(i) for i in sorted_labels0]
    # Relabelling must not change the partition itself.
    s = adjusted_rand_score(sorted_labels0, relabeled0)
    assert np.isclose(s, 1.), s

    print(indices)
    plt.figure()
    plt.scatter(np.arange(len(indices)), indices, c=cohort_info)
    ax = plt.gca()
    ax.set_aspect('equal')
    plt.show()

    # Compact patient ids into a contiguous 0..k-1 range for plotting.
    sorted_patient_of_origin = [patient_of_origin[i] for i in indices]
    make_indices_adjacent = dict(zip(sorted(set(sorted_patient_of_origin)), range(len(set(sorted_patient_of_origin)))))

    plt.figure()
    plt.scatter(relabeled0, [make_indices_adjacent[s] for s in sorted_patient_of_origin], c=[cohort_info[i] for i in indices])
    plt.show()

get_m_patients('raw')
# get_m_patients('transformed')
# get_m_patients('vae_mu')
----------------------------------------------------------------------------------------------------
score: 0.09141961348283346
(120, 121, 122, 123, 145, 179, 188, 189, 212, 218, 219, 224, 4, 5, 10, 12, 19, 28, 36, 49, 50, 182, 59, 155, 156, 157, 199, 200, 202, 214, 215, 34, 75, 139, 163, 193, 207, 208, 46, 108, 137, 148, 149, 170, 171, 3, 52, 85, 116, 135, 166, 6, 110, 115, 128, 146, 172, 15, 83, 90, 95, 106, 109, 23, 37, 38, 99, 112, 220, 26, 45, 67, 69, 71, 73, 11, 29, 47, 57, 72, 22, 25, 43, 55, 78, 27, 151, 152, 153, 203, 40, 51, 181, 186, 209, 132, 134, 195, 196, 198, 144, 159, 177, 178, 225, 7, 80, 87, 97, 13, 21, 66, 76, 18, 42, 100, 105, 39, 65, 74, 154, 54, 64, 77, 111, 82, 92, 101, 102, 91, 104, 140, 141, 103, 160, 164, 165, 117, 133, 197, 213, 124, 125, 126, 127, 142, 150, 183, 184, 143, 147, 191, 194, 173, 192, 204, 216, 176, 180, 221, 222, 1, 17, 61, 8, 81, 93, 9, 24, 79, 14, 44, 48, 16, 30, 175, 20, 41, 53, 32, 35, 88, 58, 62, 68, 60, 63, 94, 84, 86, 98, 89, 96, 107, 129, 130, 131, 167, 168, 169, 185, 187, 190, 201, 210, 211, 0, 70, 2, 174, 31, 33, 56, 158, 113, 114, 118, 119, 136, 138, 161, 162, 205, 206, 217, 223)
In [19]:
from sklearn.metrics import adjusted_rand_score
from operator import itemgetter

# Same analysis as get_m_patients('raw'), but unrolled at top level so the
# intermediates (n0, labels0, indices, sorted_labels0, make_indices_adjacent,
# sorted_patient_of_origin_adjacent, sorted_cohort_info) remain available to
# the cells that follow.
method0 = 'raw'
print('-' * 100)
# Number of DBSCAN clusters (labels assumed 0..n0-1); used by the matrix cell below.
n0 = max(q.dbscan_labels[method0]) + 1
labels0 = q.dbscan_labels[method0]
print('score:', adjusted_rand_score(labels0, patient_of_origin))
plt.figure()
plt.scatter(q.dbscan_labels[method0], patient_of_origin, c=patient_of_origin)
plt.show()

# Order samples by decreasing cluster size (ties broken by label value).
counts0 = dict(zip(*np.unique(labels0, return_counts=True)))
indices, sorted_labels0 = zip(*sorted(enumerate(labels0), key=lambda x: (-counts0[x[1]], x[1])))
print(len(sorted_labels0))
print(len(cohort_info))
print(len(patient_of_origin))
# Compact the sparse patient ids into a contiguous 0..k-1 range.
make_indices_adjacent = dict(zip(sorted(list(set(patient_of_origin))), list(range(len(set(patient_of_origin))))))
print(len(make_indices_adjacent))
patient_of_origin_adjacent = [make_indices_adjacent[p] for p in patient_of_origin]
print(len(patient_of_origin_adjacent))
# Sanity plot: compacted ids are monotone in the original ids.
plt.figure()
plt.scatter(patient_of_origin, patient_of_origin_adjacent)
plt.show()

plt.figure()
plt.scatter(q.dbscan_labels[method0], patient_of_origin_adjacent, c=cohort_info)
plt.show()

# Reorder per-sample patient/cohort info by the cluster-size sort above.
sorted_patient_of_origin_adjacent = [patient_of_origin_adjacent[i] for i in indices]
sorted_cohort_info = [cohort_info[i] for i in indices]
print(len(sorted_patient_of_origin_adjacent))
print(len(sorted_cohort_info))
plt.figure()
plt.scatter(np.arange(len(sorted_patient_of_origin_adjacent)), sorted_patient_of_origin_adjacent, c=sorted_cohort_info)
plt.show()
----------------------------------------------------------------------------------------------------
score: 0.09141961348283346
226
226
226
139
226
226
226
In [20]:
# Relabel the size-sorted cluster labels so that 0 is the largest cluster,
# 1 the second largest, and so on.  Uses sorted_labels0/indices from the
# previous cell; produces relabeled0 for the matrix cell below.
seen = set()
# First occurrence of each label, i.e. labels ordered by cluster size.
no_duplicates = [x for x in sorted_labels0 if not (x in seen or seen.add(x))]
l = list(range(max(seen) + 1))
assert len(l) == len(no_duplicates)
# NOTE(review): my_map is a lazy zip object that is never consumed — dead code.
my_map = zip(l, no_duplicates)
relabeled0 = [no_duplicates.index(i) for i in sorted_labels0]
#     print(relabeled0)
# Relabelling must not change the partition itself.
s = adjusted_rand_score(sorted_labels0, relabeled0)
assert np.isclose(s, 1.), s

#     print(labels0)
#     print(sorted_labels0)
#     print(relabeled0)

plt.figure()
plt.scatter(np.arange(len(indices)), indices, c=cohort_info)
ax = plt.gca()
ax.set_aspect('equal')
plt.show()

plt.figure()
plt.scatter(relabeled0, sorted_patient_of_origin_adjacent, c=sorted_cohort_info)
plt.show()
In [21]:
# Cluster-by-patient contingency matrix: entry (a, b) counts images assigned
# to DBSCAN cluster a that come from (compacted) patient b.
m = np.zeros((n0, len(make_indices_adjacent)))
for a, b in zip(relabeled0, sorted_patient_of_origin_adjacent):
    m[a, b] += 1

# Column-normalised and row-normalised views of the same matrix.
m_rows = m / np.sum(m, axis=0, keepdims=True)
m_cols = m / np.sum(m, axis=1, keepdims=True)

# Show the raw counts and both normalisations.
for matrix, cmap in ((m, 'inferno'), (m_rows, None), (m_cols, None)):
    plt.figure(figsize=(20, 9))
    plt.imshow(matrix, cmap=cmap)
    plt.colorbar()
    plt.show()

Looking at the DBSCAN clusters that are more numerous in raw data

In [22]:
df
Out[22]:
FileName_FullStack merged_pid diseasestatus PrimarySite Subtype clinical_type Height_FullStack Width_FullStack area sum_area_cells Count_Cells PTNM_T PTNM_N PTNM_M DFSmonth OSmonth images_per_patient images_per_patient_filtered cohort
2 BaselTMA_SP41_15.475kx12.665ky_10000x8500_5_20... 166 tumor breast PR-ER- TripleNeg 723 749 541527 356411 3068 T2 pN0 M0 35.0 37.0 1 1 basel
3 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 238 tumor breast PR+ER+ HR+HER2- 840 712 598080 286198 3173 T1 pN2 M0 140.0 233.0 2 2 basel
4 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 238 non-tumor breast PR+ER+ HR+HER2- 765 689 527085 193119 2121 T1 pN2 M0 140.0 233.0 2 2 basel
6 BaselTMA_SP41_25.475kx12.665ky_8000x8500_3_201... 68 tumor breast PR+ER+ HR+HER2- 689 688 474032 218846 2262 T1 pN2 M0 169.0 169.0 2 1 basel
7 BaselTMA_SP41_15.475kx12.665ky_10000x8500_5_20... 72 tumor breast PR-ER+ HR+HER2- 716 737 527692 289717 2740 T1 pN0 pM1 186.0 186.0 2 1 basel
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
350 ZTMA208_slide_13.25kx14.95ky_7000x7000_8_20171... 308 tumor breast PR+ER+ NaN 492 514 252888 243839 3674 TX pNX pM1 NaN NaN 6 4 zurich
354 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 503 475 238925 98597 1256 T3 pN1 pM1 NaN NaN 6 4 zurich
355 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 450 541 243450 102638 1227 T3 pN1 pM1 NaN NaN 6 4 zurich
356 ZTMA208_slide_28.23kx22.4ky_7000x7000_5_201711... 301 tumor breast PR+ER+ NaN 464 519 240816 123359 1580 T3 pN1 pM1 NaN NaN 6 4 zurich
357 ZTMA208_slide_20.73kx15.16ky_7000x7000_6_20171... 301 tumor breast PR+ER+ NaN 488 540 263520 241444 3418 T3 pN1 pM1 NaN NaN 6 4 zurich

452 rows × 19 columns

In [23]:
import torch

def get_filtered_labels_mapping(ome_index, ome_filename, split):
    """Map dense indices of the kept ("ok") cells of one ome to their original
    segmentation-mask cell ids.

    Returns a dict {position among this ome's kept cells -> original cell id},
    used to paint per-cell values back onto the masks.

    NOTE(review): despite the .npy extension, the file is a pickled dict.
    """
    import pickle  # shadows the notebook-level import; harmless but redundant
    f = os.path.join('data/spatial_uzh_processed/a', f'ok_cells_{split}.npy')
    d = pickle.load(open(f, 'rb'))
    list_of_cells = d['list_of_cells']
    list_of_ome_filenames = d['list_of_ome_filenames']
    list_of_ome_indices = d['list_of_ome_indices']
    list_of_cell_ids = d['list_of_cell_ids']
    cell_is_ok = d['cell_is_ok']
    # The flat per-cell lists are grouped by ome, so this ome's cells span the
    # first to the last occurrence of its filename.
    begin = list_of_ome_filenames.index(ome_filename)
    end = len(list_of_ome_filenames) - list_of_ome_filenames[::-1].index(ome_filename)
    # print(list_of_ome_filenames[begin])
    # print(list_of_ome_filenames[end - 1])
    # print(list_of_ome_filenames[end])
    # print(list_of_ome_filenames[end + 1])
    oks = cell_is_ok[begin: end]
    # labels = list_of_cell_ids[begin: end]
    # print(oks.shape)
    # print(t.shape)
    l0 = np.array(list(range(np.sum(oks).item())))
    # NOTE(review): `oks` is sliced by filename from the flat array while the
    # cell ids are indexed per-ome; assumes both views of the same ome align —
    # the length assert below only partially checks this.
    l1 = list_of_cell_ids[ome_index][oks]
    assert len(l0) == len(l1)
    d = dict(zip(l0, l1))
    return d
In [24]:
# STEP 1: get all the expression values — per-cell mean expressions merged
# across all training omes, one array per representation method.
# Assumes `methods` (e.g. ['raw', 'transformed', 'vae_mu']) is defined in an
# earlier cell.
from data import RawMeanDataset, TransformedMeanDataset, MasksDataset
index_info_omes, index_info_begins, index_info_ends = pickle.load(open(file_path('merged_cells_info.pickle'), 'rb'))

expressions = dict()
for method in methods:
    if method == 'raw':
        ds = RawMeanDataset('train')
        l = []
        for x in tqdm(ds, 'merging raw expressions'):
            l.append(x.numpy())
        expressions[method] = np.concatenate(l, axis=0)
    elif method == 'transformed':
        ds = TransformedMeanDataset('train')
        l = []
        for x in tqdm(ds, 'merging transformed expressions'):
            l.append(x.numpy())
        expressions[method] = np.concatenate(l, axis=0)
    elif method == 'vae_mu':
        f = os.path.join(file_path('vae_transformed_mean_dataset_LR_VB_S_0.0014685885989200848__3.8608662714605464e-08__False'), 'embedding_train.hdf5')
        with h5py.File(f, 'r') as f5:
            assert len(f5.keys()) == 1
            # The single top-level hdf5 group holds one entry per ome.
            k, v = next(iter(f5.items()))
            raw_ds = RawMeanDataset('train')
            o_train = raw_ds.filenames
            mu_list = []
            for i, o in enumerate(tqdm(o_train, desc='accessing latent space')):
                original = raw_ds[i].clone().detach().numpy()
                mu = v[o]['mu'][...]
                log_var = v[o]['log_var'][...]
                # Consistency checks between the dataset and the stored embedding.
                assert mu.shape == log_var.shape
                assert len(original) == len(mu)
                mu_list.append(mu)
            # BUG FIX: this previously concatenated `l` — the leftover list of
            # *transformed* expressions from the branch above — instead of the
            # latent means collected here (and would crash with NameError if
            # 'vae_mu' were processed first).
            expressions[method] = np.concatenate(mu_list, axis=0)
    else:
        raise ValueError()
        



In [25]:
# STEP 2: 3-component PCA of each method's expressions, with each component
# affinely rescaled to [0, 1] so the result can be used directly as RGB.
from sklearn.decomposition import PCA

pca = dict()
for method in methods:
    projected = PCA(n_components=3).fit_transform(expressions[method])
    lo = np.min(projected, axis=0)
    hi = np.max(projected, axis=0)
    pca[method] = (projected - lo) / (hi - lo)
In [26]:
# STEP 4: in the loop below replace the mess with selecting from the PCA of everything
# STEP 5: fix the plotting function to plot the right things    
In [27]:
from pprint import pprint

# For each method, order the DBSCAN cluster ids by decreasing cluster size.
# The original also computed a size-sorted (indices, labels) pair here that
# nothing used; that dead code is removed.
q.top = dict()

for method in methods:
    labels = q.dbscan_labels[method]
    # (cluster id, count) pairs, largest clusters first.
    top_full = sorted(zip(*np.unique(labels, return_counts=True)), key=lambda x: -x[1])
    q.top[method] = [t[0] for t in top_full]
    print(q.top[method][:3])
[39, 4, 29]
[7, 17, 3]
[6, 23, 4]
In [28]:
import math
masks_ds_train = MasksDataset('train')

# For each of the 8 largest vae_mu DBSCAN clusters, paint every kept cell of
# every member ome with its global 3-d PCA colour; background and filtered-out
# cells are black.
method = 'vae_mu'
for h in q.top[method][:8]:
    patients_per_hot_index = np.where(q.dbscan_labels[method] == h)[0]
    print(patients_per_hot_index)
    print('=' * 100)
    print(f'methods = {method}, hot_index = {h}')
    n_omes = len(patients_per_hot_index)
    panel_inches = 4  # was `d`, which a later assignment shadowed
    cols = 5
    rows = math.ceil(n_omes / cols)
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols * panel_inches, rows * panel_inches))
    axes = axes.flatten()
    for i, ome_index in enumerate(tqdm(patients_per_hot_index, desc='plotting omes')):
        begin = index_info_begins[ome_index]
        end = index_info_ends[ome_index]

        ome_filename = masks_ds_train.filenames[ome_index]
        masks = masks_ds_train[ome_index]

        mapping = get_filtered_labels_mapping(ome_index, ome_filename, 'train')

        # RGBA image: start white, blacken background (label 0) and cells that
        # were filtered out, then colour the kept cells by their PCA values.
        new_masks = np.ones((masks.shape[0], masks.shape[1], 4))
        kept_ids = {dd.item() for dd in mapping.values()}
        # BUG FIX: +1 so the highest cell id is also considered (the original
        # range(np.max(masks)) silently skipped it).
        omitted_labels = set(range(int(np.max(masks)) + 1)).difference(kept_ids).difference({0})
        for label in omitted_labels:
            new_masks[masks == label, :] = (0., 0., 0., 1.)
        new_masks[masks == 0, :] = (0., 0., 0., 1.)
        for k, v in mapping.items():
            new_masks[masks == v.item(), :3] = pca[method][begin: end][k, :]

        axes[i].imshow(new_masks)
        axes[i].set_title(f'ome_index = {ome_index}')
    # BUG FIX: start at i + 1 (as the sibling cells do) so the last plotted
    # panel keeps its axes; range(i, ...) switched off the last image's frame.
    for j in range(i + 1, rows * cols):
        axes[j].axis('off')
    plt.suptitle(f'3-dim pca (computed globally) of transformed cell-level data, mapped into the RGB space')
    plt.tight_layout()
    plt.show()
[  7  10  25  36  41  47  80 100 134 140 141 143 160 194 199 202 205 206
 209]
====================================================================================================
methods = vae_mu, hot_index = 6

[ 40  45  54  74  86  96 104 112 154 192 203 204 216]
====================================================================================================
methods = vae_mu, hot_index = 23

[  4  12  28  49  50  73  76  77 108 148 149 166]
====================================================================================================
methods = vae_mu, hot_index = 4

[ 17  32  42  91 118 119 132 190 200 201]
====================================================================================================
methods = vae_mu, hot_index = 12

[ 19  22  43  55  75  78 139 158 163 179]
====================================================================================================
methods = vae_mu, hot_index = 14

[ 30  88 152 153 155 159 174 211 223]
====================================================================================================
methods = vae_mu, hot_index = 19

[  9  13  35  37  63  95  97 101]
====================================================================================================
methods = vae_mu, hot_index = 8

[ 21  57  66  84  93  94 177 191]
====================================================================================================
methods = vae_mu, hot_index = 16

In [29]:
import math
masks_ds_train = MasksDataset('train')

# Same panels as above, but colour each kept cell by its phenograph cluster
# (one random colour per cluster id) instead of the PCA embedding.
# `random_colors` is also reused by the bar-plot cell below.
method = 'vae_mu'
random_colors = np.random.rand(1000, 3)
for h in q.top[method][:8]:
    patients_per_hot_index = np.where(q.dbscan_labels[method] == h)[0]
    print(patients_per_hot_index)
    print('=' * 100)
    print(f'methods = {method}, hot_index = {h}')
    n_omes = len(patients_per_hot_index)
    panel_inches = 4  # was `d`, which a later assignment shadowed
    cols = 5
    rows = math.ceil(n_omes / cols)
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols * panel_inches, rows * panel_inches))
    axes = axes.flatten()
    for i, ome_index in enumerate(tqdm(patients_per_hot_index, desc='plotting omes')):
        begin = index_info_begins[ome_index]
        end = index_info_ends[ome_index]

        ome_filename = masks_ds_train.filenames[ome_index]
        masks = masks_ds_train[ome_index]

        mapping = get_filtered_labels_mapping(ome_index, ome_filename, 'train')

        # RGBA image: white canvas, black background/filtered cells, then one
        # random colour per phenograph cluster for the kept cells.
        new_masks = np.ones((masks.shape[0], masks.shape[1], 4))
        kept_ids = {dd.item() for dd in mapping.values()}
        # BUG FIX: +1 so the highest cell id is also considered (the original
        # range(np.max(masks)) silently skipped it).
        omitted_labels = set(range(int(np.max(masks)) + 1)).difference(kept_ids).difference({0})
        for label in omitted_labels:
            new_masks[masks == label, :] = (0., 0., 0., 1.)
        new_masks[masks == 0, :] = (0., 0., 0., 1.)
        for k, v in mapping.items():
            new_masks[masks == v.item(), :3] = random_colors[q.phenographs[method][begin: end][k, :]]

        axes[i].imshow(new_masks)
        axes[i].set_title(f'ome_index = {ome_index}')
    for j in range(i + 1, rows * cols):
        axes[j].axis('off')
    # BUG FIX: the suptitle was copy-pasted from the PCA cell; these panels
    # colour cells by phenograph cluster, not by a PCA embedding.
    plt.suptitle(f'phenograph clusters of {method} cell-level data, one random colour per cluster')
    plt.tight_layout()
    plt.show()
[  7  10  25  36  41  47  80 100 134 140 141 143 160 194 199 202 205 206
 209]
====================================================================================================
methods = vae_mu, hot_index = 6

[ 40  45  54  74  86  96 104 112 154 192 203 204 216]
====================================================================================================
methods = vae_mu, hot_index = 23

[  4  12  28  49  50  73  76  77 108 148 149 166]
====================================================================================================
methods = vae_mu, hot_index = 4

[ 17  32  42  91 118 119 132 190 200 201]
====================================================================================================
methods = vae_mu, hot_index = 12

[ 19  22  43  55  75  78 139 158 163 179]
====================================================================================================
methods = vae_mu, hot_index = 14

[ 30  88 152 153 155 159 174 211 223]
====================================================================================================
methods = vae_mu, hot_index = 19

[  9  13  35  37  63  95  97 101]
====================================================================================================
methods = vae_mu, hot_index = 8

[ 21  57  66  84  93  94 177 191]
====================================================================================================
methods = vae_mu, hot_index = 16

In [30]:
import math
masks_ds_train = MasksDataset('train')

# For each ome of the 8 largest clusters, bar-plot its per-phenograph-class
# fraction vector (the features DBSCAN clustered on).  Bar colours reuse
# `random_colors` from the previous cell so they match the cell colouring
# there (hidden-state dependency: run that cell first).
method = 'vae_mu'
for h in q.top[method][:8]:
    patients_per_hot_index = np.where(q.dbscan_labels[method] == h)[0]
    print(patients_per_hot_index)
    print('=' * 100)
    print(f'methods = {method}, hot_index = {h}')
    n_omes = len(patients_per_hot_index)
    panel_inches = 4
    cols = 5
    rows = math.ceil(n_omes / cols)
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols * panel_inches, rows * panel_inches))
    axes = axes.flatten()
    for i, ome_index in enumerate(tqdm(patients_per_hot_index, desc='plotting omes')):
        f = q.all_fractions[method][ome_index, :]
        axes[i].bar(np.arange(len(f)), f, color=[random_colors[j] for j in np.arange(len(f))])
        axes[i].set_title(f'ome_index = {ome_index}')
    for j in range(i + 1, rows * cols):
        axes[j].axis('off')
    # BUG FIX: the suptitle was copy-pasted from the PCA cell; these are
    # phenograph-class fraction bar plots, not a PCA visualisation.
    plt.suptitle(f'per-ome phenograph class fractions (bar colour = class id)')
    plt.tight_layout()
    plt.show()
[  7  10  25  36  41  47  80 100 134 140 141 143 160 194 199 202 205 206
 209]
====================================================================================================
methods = vae_mu, hot_index = 6

[ 40  45  54  74  86  96 104 112 154 192 203 204 216]
====================================================================================================
methods = vae_mu, hot_index = 23

[  4  12  28  49  50  73  76  77 108 148 149 166]
====================================================================================================
methods = vae_mu, hot_index = 4

[ 17  32  42  91 118 119 132 190 200 201]
====================================================================================================
methods = vae_mu, hot_index = 12

[ 19  22  43  55  75  78 139 158 163 179]
====================================================================================================
methods = vae_mu, hot_index = 14

[ 30  88 152 153 155 159 174 211 223]
====================================================================================================
methods = vae_mu, hot_index = 19

[  9  13  35  37  63  95  97 101]
====================================================================================================
methods = vae_mu, hot_index = 8

[ 21  57  66  84  93  94 177 191]
====================================================================================================
methods = vae_mu, hot_index = 16

In [31]:
# Reorder the fractions matrix by complete-linkage hierarchical clustering,
# first along one axis and then the other, displaying the matrix each time.
import scipy
import scipy.cluster.hierarchy as sch

def _cluster_order(matrix):
    # Row permutation that groups rows by their flat-cluster assignment
    # (complete linkage, cut at half the maximum pairwise distance).
    pairwise = sch.distance.pdist(matrix)
    linkage = sch.linkage(pairwise, method='complete')
    assignment = sch.fcluster(linkage, 0.5 * pairwise.max(), 'distance')
    order, _ = zip(*sorted(zip(range(len(assignment)), assignment), key=lambda x: x[1]))
    return np.array(order)

m = q.all_fractions['vae_mu'].T
plt.matshow(m)
ii = _cluster_order(m)
plt.matshow(m[ii, :])

m = m.T  # back to (omes, classes); the following cells use this orientation
ii = _cluster_order(m)
plt.matshow(m[ii, :].T)
Out[31]:
<matplotlib.image.AxesImage at 0x7f27c6b84ee0>
In [32]:
import seaborn as sns
# NOTE(review): this rebinding clobbers the metadata DataFrame `df` used by
# the earlier cells; re-running those cells after this one would operate on
# the wrong data.  Consider a distinct name (e.g. fractions_df).
df = pd.DataFrame(m)
a = sns.clustermap(df)
In [33]:
# Row order chosen by the clustermap, and the cohort of each reordered ome:
# a visual check of whether the hierarchical clustering separates the cohorts.
b = a.data2d.index.to_numpy()
i = np.arange(len(b))
plt.plot(i, b)
plt.show()
# b[i] is just b (i is the identity range), so this plots the cohort labels
# along the clustered order.
plt.plot(i, cohort_info[b], '-o')
plt.show()
In [ ]:
 
In [ ]: